/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.spi.ConnectorId;
import com.facebook.presto.spi.ConnectorPlanOptimizer;
import com.facebook.presto.spi.ConnectorSession;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.CteConsumerNode;
import com.facebook.presto.spi.plan.CteProducerNode;
import com.facebook.presto.spi.plan.CteReferenceNode;
import com.facebook.presto.spi.plan.DeleteNode;
import com.facebook.presto.spi.plan.DistinctLimitNode;
import com.facebook.presto.spi.plan.ExceptNode;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.IndexJoinNode;
import com.facebook.presto.spi.plan.IndexSourceNode;
import com.facebook.presto.spi.plan.IntersectNode;
import com.facebook.presto.spi.plan.JoinNode;
import com.facebook.presto.spi.plan.LimitNode;
import com.facebook.presto.spi.plan.MarkDistinctNode;
import com.facebook.presto.spi.plan.MaterializedViewScanNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.plan.SemiJoinNode;
import com.facebook.presto.spi.plan.SortNode;
import com.facebook.presto.spi.plan.TableFinishNode;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.spi.plan.TableWriterNode;
import com.facebook.presto.spi.plan.TopNNode;
import com.facebook.presto.spi.plan.UnionNode;
import com.facebook.presto.spi.plan.UnnestNode;
import com.facebook.presto.spi.plan.ValuesNode;
import com.facebook.presto.sql.planner.TypeProvider;
import com.google.common.base.Supplier;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;

import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Queue;
import java.util.Set;

import static com.facebook.presto.SystemSessionProperties.isEmptyConnectorOptimizerEnabled;
import static com.facebook.presto.SystemSessionProperties.isIncludeValuesNodeInConnectorOptimizer;
import static com.facebook.presto.common.RuntimeUnit.NANO;
import static com.facebook.presto.sql.OptimizerRuntimeTrackUtil.getOptimizerNameForLog;
import static com.facebook.presto.sql.OptimizerRuntimeTrackUtil.trackOptimizerRuntime;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static java.util.Objects.requireNonNull;

/**
 * Applies {@link ConnectorPlanOptimizer}s to the plan.
 *
 * <p>For every connector reachable from the plan's leaves, that connector's optimizers are split into runs that
 * share the same list of supported connectors. Each run is handed every "max closure" of the plan: a maximal subtree
 * whose reachable connectors match the run's supported connectors and which contains only
 * {@link #CONNECTOR_ACCESSIBLE_PLAN_NODES}. The rewritten subtrees are then spliced back into the plan.
 */
public class ApplyConnectorOptimization
        implements PlanOptimizer
{
    /**
     * Plan node types a connector optimizer is allowed to see; a subtree reaching any other node type is never a closure.
     */
    static final Set<Class<? extends PlanNode>> CONNECTOR_ACCESSIBLE_PLAN_NODES = ImmutableSet.of(
            CteProducerNode.class,
            CteConsumerNode.class,
            CteReferenceNode.class,
            DistinctLimitNode.class,
            FilterNode.class,
            TableScanNode.class,
            IndexSourceNode.class,
            LimitNode.class,
            SortNode.class,
            TopNNode.class,
            ValuesNode.class,
            ProjectNode.class,
            AggregationNode.class,
            MarkDistinctNode.class,
            MaterializedViewScanNode.class,
            UnionNode.class,
            IntersectNode.class,
            ExceptNode.class,
            SemiJoinNode.class,
            JoinNode.class,
            IndexJoinNode.class,
            UnnestNode.class,
            TableWriterNode.class,
            TableFinishNode.class,
            DeleteNode.class);

    // for a leaf node that does not belong to any connector (e.g., ValuesNode)
    private static final ConnectorId EMPTY_CONNECTOR_ID = new ConnectorId("$internal$ApplyConnectorOptimization_EMPTY_CONNECTOR");

    private final Supplier<Map<ConnectorId, Set<ConnectorPlanOptimizer>>> connectorOptimizersSupplier;

    public ApplyConnectorOptimization(Supplier<Map<ConnectorId, Set<ConnectorPlanOptimizer>>> connectorOptimizersSupplier)
    {
        this.connectorOptimizersSupplier = requireNonNull(connectorOptimizersSupplier, "connectorOptimizersSupplier is null");
    }

    /**
     * Runs the registered connector optimizers over the max closures of {@code plan}.
     *
     * @return the rewritten plan; the "triggered" flag is {@code true} only when some connector optimizer actually
     * produced a different plan
     */
    @Override
    public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider types, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector)
    {
        requireNonNull(plan, "plan is null");
        requireNonNull(session, "session is null");
        requireNonNull(types, "types is null");
        requireNonNull(variableAllocator, "variableAllocator is null");
        requireNonNull(idAllocator, "idAllocator is null");

        boolean enableVerboseRuntimeStats = SystemSessionProperties.isVerboseRuntimeStatsEnabled(session);
        Map<ConnectorId, Set<ConnectorPlanOptimizer>> connectorOptimizers = connectorOptimizersSupplier.get();
        if (connectorOptimizers.isEmpty()) {
            return PlanOptimizerResult.optimizerResult(plan, false);
        }

        // remembered so the result reports whether any connector optimizer changed the plan
        PlanNode originalPlan = plan;

        // retrieve all the connectors
        ImmutableSet.Builder<ConnectorId> connectorIds = ImmutableSet.builder();
        getAllConnectorIds(plan, connectorIds);
        Set<ConnectorId> connectorIdSet = connectorIds.build();

        // When every leaf is connector-less (e.g., only ValuesNodes) and the session has a catalog, the catalog's
        // optimizers that declare support for exactly the empty connector are used instead. This does not depend
        // on the loop variable below, so it is computed once.
        boolean useCatalogEmptyConnectorOptimizers = isEmptyConnectorOptimizerEnabled(session)
                && connectorIdSet.stream().allMatch(id -> id.equals(EMPTY_CONNECTOR_ID))
                && session.getCatalog().isPresent();

        // for each connector, retrieve the set of subplans to optimize
        // TODO: what if a new connector is added by an existing one
        // There are cases (e.g., query federation) where a connector C1 needs to
        // create a UNION_ALL to federate data sources from both C1 and C2 (regardless of the classloader issue).
        // For such case, it is dangerous to re-calculate the "max closure" given the fixpoint property will be broken.
        // In order to preserve the fixpoint, we will "pretend" the newly added C2 table scan is part of C1's job to maintain.
        for (ConnectorId connectorId : connectorIdSet) {
            Set<ConnectorPlanOptimizer> optimizers = useCatalogEmptyConnectorOptimizers
                    ? getEmptyConnectorOptimizers(connectorOptimizers, new ConnectorId(session.getCatalog().get()))
                    : connectorOptimizers.get(connectorId);
            if (optimizers == null || optimizers.isEmpty()) {
                continue;
            }

            for (OptimizerGroup group : groupBySupportedConnectors(optimizers, connectorId)) {
                plan = applyOptimizerGroup(plan, connectorId, group, session, variableAllocator, idAllocator, enableVerboseRuntimeStats);
            }
        }

        return PlanOptimizerResult.optimizerResult(plan, plan != originalPlan);
    }

    /**
     * Returns the optimizers registered for {@code catalogConnectorId} that declare support for exactly the empty
     * connector, or an empty set when the catalog has no registered optimizers.
     */
    private static Set<ConnectorPlanOptimizer> getEmptyConnectorOptimizers(
            Map<ConnectorId, Set<ConnectorPlanOptimizer>> connectorOptimizers,
            ConnectorId catalogConnectorId)
    {
        Set<ConnectorPlanOptimizer> catalogOptimizers = connectorOptimizers.get(catalogConnectorId);
        if (catalogOptimizers == null) {
            return ImmutableSet.of();
        }
        return catalogOptimizers.stream()
                .filter(optimizer -> isEmptyConnectorOnly(optimizer.getSupportedConnectorIds()))
                .collect(toImmutableSet());
    }

    /**
     * Splits {@code optimizers} (in iteration order) into runs of consecutive optimizers sharing the same supported
     * connector list. An optimizer with no declared supported connectors is treated as supporting only
     * {@code connectorId}.
     *
     * <p>A list is returned rather than a map so that the same supported-connector list may appear in more than one
     * non-adjacent run without a duplicate-key failure.
     */
    private static List<OptimizerGroup> groupBySupportedConnectors(Set<ConnectorPlanOptimizer> optimizers, ConnectorId connectorId)
    {
        ImmutableList.Builder<OptimizerGroup> groups = ImmutableList.builder();
        List<ConnectorId> currentConnectors = null;
        ImmutableSet.Builder<ConnectorPlanOptimizer> currentGroup = null;
        for (ConnectorPlanOptimizer optimizer : optimizers) {
            List<ConnectorId> supportedConnectors = optimizer.getSupportedConnectorIds().isEmpty()
                    ? ImmutableList.of(connectorId)
                    : optimizer.getSupportedConnectorIds();

            if (!supportedConnectors.equals(currentConnectors)) {
                if (currentGroup != null) {
                    groups.add(new OptimizerGroup(currentConnectors, currentGroup.build()));
                }
                currentConnectors = supportedConnectors;
                currentGroup = ImmutableSet.builder();
            }
            currentGroup.add(optimizer);
        }
        if (currentGroup != null) {
            groups.add(new OptimizerGroup(currentConnectors, currentGroup.build()));
        }
        return groups.build();
    }

    /**
     * Hands every max closure of {@code plan} to the optimizers of {@code group}, then rebuilds the ancestors of each
     * rewritten subtree.
     *
     * @return the rewritten plan, or {@code plan} itself when no optimizer returned a different node
     * @throws IllegalStateException if an optimizer returns a node that drops an output variable of the node it was given
     */
    private static PlanNode applyOptimizerGroup(
            PlanNode plan,
            ConnectorId connectorId,
            OptimizerGroup group,
            Session session,
            VariableAllocator variableAllocator,
            PlanNodeIdAllocator idAllocator,
            boolean enableVerboseRuntimeStats)
    {
        // keep track of changed nodes; the keys are original nodes and the values are the new nodes
        Map<PlanNode, PlanNode> updates = new HashMap<>();

        // rebuilt for every group because an earlier group may already have rewritten the plan
        ImmutableMap.Builder<PlanNode, ConnectorPlanNodeContext> contextMapBuilder = ImmutableMap.builder();
        buildConnectorPlanNodeContext(plan, null, contextMapBuilder);
        Map<PlanNode, ConnectorPlanNodeContext> contextMap = contextMapBuilder.build();

        List<ConnectorId> supportedConnectors = group.getSupportedConnectors();
        for (Map.Entry<PlanNode, ConnectorPlanNodeContext> contextEntry : contextMap.entrySet()) {
            PlanNode node = contextEntry.getKey();
            ConnectorPlanNodeContext context = contextEntry.getValue();

            // `node` is treated as a max closure when:
            //    * the subtree rooted at `node` is a closure, and
            //    * `node` has a parent, and the subtree rooted at that parent is not a closure.
            // NOTE(review): a parentless closure (the plan root) is not considered, so the root itself is never
            // passed to a connector optimizer. This preserves existing behavior; confirm it is intentional.
            if (!context.isClosure(connectorId, session, supportedConnectors) ||
                    !context.getParent().isPresent() ||
                    contextMap.get(context.getParent().get()).isClosure(connectorId, session, supportedConnectors)) {
                continue;
            }

            // the session does not depend on the optimizer, so it is created once per max closure
            ConnectorSession connectorSession = toOptimizerConnectorSession(session, connectorId);

            // the returned node is still a max closure (only if there is no new connector added, which does happen but ignored here)
            PlanNode newNode = node;
            for (ConnectorPlanOptimizer optimizer : group.getOptimizers()) {
                long start = System.nanoTime();
                newNode = optimizer.optimize(newNode, connectorSession, variableAllocator, idAllocator);
                if (enableVerboseRuntimeStats || trackOptimizerRuntime(session, optimizer)) {
                    session.getRuntimeStats().addMetricValue(String.format("optimizer%sTimeNanos", getOptimizerNameForLog(optimizer)), NANO, System.nanoTime() - start);
                }
            }

            if (node != newNode) {
                // the optimizer has allocated a new PlanNode
                checkState(
                        containsAll(ImmutableSet.copyOf(newNode.getOutputVariables()), node.getOutputVariables()),
                        "the connector optimizer from %s returns a node that does not cover all output before optimization",
                        connectorId);

                updates.put(node, newNode);
            }
        }

        return rebuildAncestors(plan, updates, contextMap);
    }

    /**
     * Creates the {@link ConnectorSession} passed to connector optimizers. When the empty-connector optimizer is
     * enabled and {@code connectorId} is the internal empty connector, the session's catalog is used instead.
     */
    private static ConnectorSession toOptimizerConnectorSession(Session session, ConnectorId connectorId)
    {
        ConnectorId sessionConnectorId = connectorId;
        if (isEmptyConnectorOptimizerEnabled(session) && connectorId.equals(EMPTY_CONNECTOR_ID) && session.getCatalog().isPresent()) {
            sessionConnectorId = new ConnectorId(session.getCatalog().get());
        }
        ConnectorSession connectorSession = session.toConnectorSession(sessionConnectorId);
        checkState(connectorSession.getConnectorId().isPresent(), "connector session for %s has no connector id", sessionConnectorId);
        return connectorSession;
    }

    /**
     * Propagates the rewritten nodes in {@code updates} up to the root, rebuilding every ancestor so it points at its
     * new children.
     *
     * <p>The traversal is not strictly bottom-up. Every rebuilt node re-enqueues its parent, so the last rebuild of
     * any node always sees the final versions of its children. The original plan structure is read from
     * {@code contextMap}.
     *
     * @return the new root, or {@code plan} when {@code updates} is empty
     */
    private static PlanNode rebuildAncestors(PlanNode plan, Map<PlanNode, PlanNode> updates, Map<PlanNode, ConnectorPlanNodeContext> contextMap)
    {
        PlanNode root = plan;
        Queue<PlanNode> originalNodes = new LinkedList<>(updates.keySet());
        while (!originalNodes.isEmpty()) {
            PlanNode originalNode = originalNodes.poll();

            Optional<PlanNode> parent = contextMap.get(originalNode).getParent();
            if (!parent.isPresent()) {
                // originalNode must be the root; update the plan
                root = updates.get(originalNode);
                continue;
            }
            PlanNode originalParent = parent.get();

            // if a child has been updated it occurs in `updates`; otherwise keep the original child
            ImmutableList.Builder<PlanNode> newChildren = ImmutableList.builder();
            for (PlanNode child : originalParent.getSources()) {
                newChildren.add(updates.getOrDefault(child, child));
            }
            updates.put(originalParent, originalParent.replaceChildren(newChildren.build()));

            // enqueue the parent node in order to recursively update its ancestors
            originalNodes.add(originalParent);
        }
        return root;
    }

    /**
     * Returns the connector a leaf node reads from: the table's connector for {@link TableScanNode} and
     * {@link IndexSourceNode}, and {@link #EMPTY_CONNECTOR_ID} for any other leaf.
     */
    private static ConnectorId getLeafConnectorId(PlanNode leaf)
    {
        if (leaf instanceof TableScanNode) {
            return ((TableScanNode) leaf).getTable().getConnectorId();
        }
        if (leaf instanceof IndexSourceNode) {
            return ((IndexSourceNode) leaf).getTableHandle().getConnectorId();
        }
        return EMPTY_CONNECTOR_ID;
    }

    /**
     * Returns whether {@code supportedConnectorIds} consists of exactly the internal empty connector.
     */
    private static boolean isEmptyConnectorOnly(List<ConnectorId> supportedConnectorIds)
    {
        return supportedConnectorIds.size() == 1 && supportedConnectorIds.get(0).equals(EMPTY_CONNECTOR_ID);
    }

    /**
     * Adds the connector of every leaf reachable from {@code node} to {@code builder}.
     */
    private static void getAllConnectorIds(PlanNode node, ImmutableSet.Builder<ConnectorId> builder)
    {
        if (node.getSources().isEmpty()) {
            builder.add(getLeafConnectorId(node));
            return;
        }

        for (PlanNode child : node.getSources()) {
            getAllConnectorIds(child, builder);
        }
    }

    /**
     * Recursively records, for {@code node} and every descendant, its parent plus the connectors and plan node types
     * reachable from it, and returns the context of {@code node}.
     */
    private static ConnectorPlanNodeContext buildConnectorPlanNodeContext(
            PlanNode node,
            PlanNode parent,
            ImmutableMap.Builder<PlanNode, ConnectorPlanNodeContext> contextBuilder)
    {
        Set<ConnectorId> connectorIds;
        Set<Class<? extends PlanNode>> planNodeTypes;

        if (node.getSources().isEmpty()) {
            connectorIds = ImmutableSet.of(getLeafConnectorId(node));
            // table-reading leaves are recorded under their base class so subclasses still match the accessible set
            if (node instanceof TableScanNode) {
                planNodeTypes = ImmutableSet.of(TableScanNode.class);
            }
            else if (node instanceof IndexSourceNode) {
                planNodeTypes = ImmutableSet.of(IndexSourceNode.class);
            }
            else {
                planNodeTypes = ImmutableSet.of(node.getClass());
            }
        }
        else {
            connectorIds = new HashSet<>();
            planNodeTypes = new HashSet<>();

            for (PlanNode child : node.getSources()) {
                ConnectorPlanNodeContext childContext = buildConnectorPlanNodeContext(child, node, contextBuilder);
                connectorIds.addAll(childContext.getReachableConnectors());
                planNodeTypes.addAll(childContext.getReachablePlanNodeTypes());
            }
            planNodeTypes.add(node.getClass());
        }

        ConnectorPlanNodeContext connectorPlanNodeContext = new ConnectorPlanNodeContext(
                parent,
                connectorIds,
                planNodeTypes);

        contextBuilder.put(node, connectorPlanNodeContext);
        return connectorPlanNodeContext;
    }

    /**
     * A run of consecutive connector optimizers that share the same list of supported connectors.
     */
    private static final class OptimizerGroup
    {
        private final List<ConnectorId> supportedConnectors;
        private final Set<ConnectorPlanOptimizer> optimizers;

        OptimizerGroup(List<ConnectorId> supportedConnectors, Set<ConnectorPlanOptimizer> optimizers)
        {
            this.supportedConnectors = requireNonNull(supportedConnectors, "supportedConnectors is null");
            this.optimizers = requireNonNull(optimizers, "optimizers is null");
        }

        List<ConnectorId> getSupportedConnectors()
        {
            return supportedConnectors;
        }

        Set<ConnectorPlanOptimizer> getOptimizers()
        {
            return optimizers;
        }
    }

    /**
     * Extra information needed for a plan node: its parent (absent for the root), and the connectors and plan node
     * types reachable from it.
     */
    private static final class ConnectorPlanNodeContext
    {
        private final PlanNode parent;
        private final Set<ConnectorId> reachableConnectors;
        private final Set<Class<? extends PlanNode>> reachablePlanNodeTypes;

        ConnectorPlanNodeContext(PlanNode parent, Set<ConnectorId> reachableConnectors, Set<Class<? extends PlanNode>> reachablePlanNodeTypes)
        {
            this.parent = parent;
            this.reachableConnectors = requireNonNull(reachableConnectors, "reachableConnectors is null");
            this.reachablePlanNodeTypes = requireNonNull(reachablePlanNodeTypes, "reachablePlanNodeTypes is null");
            checkArgument(!reachableConnectors.isEmpty(), "encountered a PlanNode that reaches no connector");
            checkArgument(!reachablePlanNodeTypes.isEmpty(), "encountered a PlanNode that reaches no plan node");
        }

        Optional<PlanNode> getParent()
        {
            return Optional.ofNullable(parent);
        }

        public Set<ConnectorId> getReachableConnectors()
        {
            return reachableConnectors;
        }

        public Set<Class<? extends PlanNode>> getReachablePlanNodeTypes()
        {
            return reachablePlanNodeTypes;
        }

        /**
         * Returns whether the subtree rooted at this node may be handed to optimizers of {@code connectorId} that
         * support {@code supportedConnectorIds}.
         *
         * <p>The subtree qualifies only if every reachable plan node type is in {@code CONNECTOR_ACCESSIBLE_PLAN_NODES}
         * and one of the following holds:
         * <ul>
         * <li>the empty-connector optimizer is enabled, only connector-less leaves are reachable, and the optimizers
         * support exactly the empty connector; or</li>
         * <li>the reachable connectors contain {@code connectorId} and match {@code supportedConnectorIds} (same
         * members, same size). When values nodes are included via the session property, connector-less leaves are
         * ignored for this comparison.</li>
         * </ul>
         */
        boolean isClosure(ConnectorId connectorId, Session session, List<ConnectorId> supportedConnectorIds)
        {
            if (isEmptyConnectorOptimizerEnabled(session)
                    && reachableConnectors.stream().allMatch(id -> id.equals(EMPTY_CONNECTOR_ID))
                    && isEmptyConnectorOnly(supportedConnectorIds)) {
                return containsAll(CONNECTOR_ACCESSIBLE_PLAN_NODES, reachablePlanNodeTypes);
            }
            // check if all children can reach the only connector
            boolean includeValuesNode = isIncludeValuesNodeInConnectorOptimizer(session);
            Set<ConnectorId> connectorIds = includeValuesNode
                    ? reachableConnectors.stream().filter(id -> !id.equals(EMPTY_CONNECTOR_ID)).collect(toImmutableSet())
                    : reachableConnectors;
            if (connectorIds.contains(connectorId)
                    && new HashSet<>(supportedConnectorIds).containsAll(connectorIds)
                    && supportedConnectorIds.size() == connectorIds.size()) {
                // check if all children are accessible by connectors
                return containsAll(CONNECTOR_ACCESSIBLE_PLAN_NODES, reachablePlanNodeTypes);
            }
            return false;
        }
    }

    /**
     * Returns whether every element of {@code test} is contained in {@code container}.
     */
    private static <T> boolean containsAll(Set<T> container, Collection<T> test)
    {
        for (T element : test) {
            if (!container.contains(element)) {
                return false;
            }
        }
        return true;
    }
}
